# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' Internal default value initialization scheme.
#'
#' Chooses an initial value for a non-weight variable from the suffix of its
#' name: names ending in "bias", "beta" or "moving_mean" get
#' \code{mx.nd.zeros(shape)}; names ending in "gamma" or "moving_var" get
#' \code{mx.nd.ones(shape)}.
#'
#' @param name the name of the variable.
#' @param shape the shape of the array to be generated.
#' @param ctx accepted for signature compatibility with the initializers;
#' not used by this function.
#' @param allow.unknown if \code{TRUE}, return \code{NULL} for a name that
#' matches no known suffix instead of raising an error.
#' @return The generated array, or \code{NULL} when the name is unknown and
#' \code{allow.unknown} is \code{TRUE}.
#'
mx.init.internal.default <- function(name, shape, ctx, allow.unknown = FALSE) {
  # Variables that start at zero.
  if (endsWith(name, "bias")) return(mx.nd.zeros(shape))
  if (endsWith(name, "beta")) return(mx.nd.zeros(shape))
  if (endsWith(name, "moving_mean")) return(mx.nd.zeros(shape))

  # Variables that start at one.
  if (endsWith(name, "gamma")) return(mx.nd.ones(shape))
  if (endsWith(name, "moving_var")) return(mx.nd.ones(shape))

  # No suffix matched: either let the caller skip this entry or fail loudly.
  if (allow.unknown) return(NULL)
  stop(paste0("Unknown initialization pattern for ", name))
}

#' Create an initializer that initializes the weight with uniform [-scale, scale]
#'
#' Variables whose name does not end in "weight" are delegated to
#' \code{mx.init.internal.default}.
#'
#' @param scale The scale of uniform distribution
#' @return A function \code{(name, shape, ctx, allow.unknown = FALSE)} that
#' produces the initial value for the variable \code{name} of the given
#' \code{shape}. \code{ctx} is accepted but not used by the returned function.
#'
#' @export
mx.init.uniform <- function(scale) {
  # Evaluate the argument now so the returned closure keeps the value given at
  # creation time rather than a lazy promise that may see a later value of the
  # caller's variable (e.g. when initializers are built inside a loop).
  force(scale)
  function(name, shape, ctx, allow.unknown = FALSE) {
    if (!endsWith(name, "weight")) {
      return(mx.init.internal.default(name = name, shape = shape,
                                      allow.unknown = allow.unknown))
    }
    mx.nd.random.uniform(low = -scale, high = scale, shape = shape)
  }
}

#' Create an initializer that initializes the weight with normal(0, sd)
#'
#' Variables whose name does not end in "weight" are delegated to
#' \code{mx.init.internal.default}.
#'
#' @param sd The standard deviation of normal distribution
#' @return A function \code{(name, shape, ctx, allow.unknown = FALSE)} that
#' produces the initial value for the variable \code{name} of the given
#' \code{shape}. \code{ctx} is accepted but not used by the returned function.
#'
#' @export
mx.init.normal <- function(sd) {
  # Evaluate the argument now so the returned closure keeps the value given at
  # creation time rather than a lazy promise that may see a later value of the
  # caller's variable (e.g. when initializers are built inside a loop).
  force(sd)
  function(name, shape, ctx, allow.unknown = FALSE) {
    if (!endsWith(name, "weight")) {
      return(mx.init.internal.default(name = name, shape = shape,
                                      allow.unknown = allow.unknown))
    }
    mx.nd.random.normal(loc = 0, scale = sd, shape = shape)
  }
}

#' @title Xavier initializer
#'
#' @description Create an initializer which initializes weights with Xavier or
#' a similar initialization scheme. Variables whose name does not end in
#' "weight" are delegated to \code{mx.init.internal.default}.
#'
#' @details For a weight of shape \code{shape}, \code{fan_out} is the last
#' element of \code{shape} and \code{fan_in} is the product of the remaining
#' elements. The random range is \code{sqrt(magnitude / factor)}, where
#' \code{factor} is chosen by \code{factor_type}.
#'
#' @param rnd_type A string of \code{character} indicating the type of
#' distribution from which the weights are initialized: \code{"uniform"} or
#' \code{"gaussian"}.
#' @param factor_type A string of \code{character}: \code{"avg"} uses
#' \code{(fan_in + fan_out) / 2}, \code{"in"} uses \code{fan_in} and
#' \code{"out"} uses \code{fan_out}.
#' @param magnitude A \code{numeric} number indicating the scale of random
#' number range.
#' @return A function \code{(name, shape, ctx, allow.unknown = FALSE)} that
#' produces the initial value for the variable \code{name} of the given
#' \code{shape}. \code{ctx} is accepted but not used by the returned function.
#' An unsupported \code{rnd_type} or \code{factor_type} raises an error when
#' the returned function is applied to a weight.
#' @export
mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg",
                           magnitude = 3){
  # Evaluate the arguments now so the returned closure keeps the values given
  # at creation time rather than lazy promises that may see later values of
  # the caller's variables (e.g. when initializers are built inside a loop).
  force(rnd_type)
  force(factor_type)
  force(magnitude)

  function(name, shape, ctx, allow.unknown = FALSE){
    if (!endsWith(name, "weight")) {
      return(mx.init.internal.default(name = name, shape = shape,
                                      allow.unknown = allow.unknown))
    }

    # Last dimension is the fan-out; everything else multiplies into fan-in.
    fan_out <- shape[length(shape)]
    fan_in <- prod(shape[-length(shape)])
    factor_val <- switch(factor_type,
                         "avg" = (fan_in + fan_out) / 2,
                         "in" = fan_in,
                         "out" = fan_out,
                         stop("Not supported factor type. See usage of function mx.init.Xavier"))

    scale <- sqrt(magnitude / factor_val)

    if (rnd_type == "uniform"){
      mx.nd.random.uniform(low = -scale, high = scale, shape = shape)
    } else if (rnd_type == "gaussian"){
      mx.nd.random.normal(loc = 0, scale = scale, shape = shape)
    } else {
      stop("Not supported random type. See usage of function mx.init.Xavier")
    }
  }
}


#' Create initial values for a set of arguments, like arg.array
#'
#' Applies \code{initializer} to every entry of \code{shape.array}, passing
#' the entry's name, its shape, \code{ctx} and
#' \code{allow.unknown = skip.unknown}. The results are returned as a list
#' named after \code{shape.array}; when \code{skip.unknown} is \code{TRUE} the
#' list is passed through \code{mx.util.filter.null} before being returned.
#'
#' @param initializer The initializer.
#' @param shape.array A named list that represents the shape of the weights
#' @param ctx mx.context The context of the weights
#' @param skip.unknown Whether skip the unknown weight types
#' @return A named list of initialized values; an empty list when
#' \code{shape.array} is empty.
#' @export
mx.init.create <- function(initializer, shape.array, ctx = NULL, skip.unknown = TRUE) {
  if (length(shape.array) == 0) {
    return(list())
  }

  arg.names <- names(shape.array)

  # Initialize the idx-th entry from its name and shape.
  init.one <- function(idx) {
    initializer(arg.names[[idx]], shape.array[[idx]], ctx,
                allow.unknown = skip.unknown)
  }

  arrays <- lapply(seq_along(arg.names), init.one)
  names(arrays) <- arg.names

  if (!skip.unknown) {
    return(arrays)
  }
  mx.util.filter.null(arrays)
}
